{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) Microsoft Corporation. All rights reserved.  \n",
    "Licensed under the MIT License.\n",
    "\n",
    "# Inference TensorFlow Bert Model for High Performance in ONNX Runtime #\n",
    "\n",
    "This tutorial shows how to convert the original Tensorflow Bert model into ONNX, and then inference it with ONNX Runtime for high performance with transformer optimization. In the following sections, we are going to use the Bert model trained with Stanford Question Answering Dataset (SQuAD) dataset as an example. Bert SQuAD model is used in question answering scenarios, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 - Convert a TensorFlow Bert model to ONNX\n",
    "\n",
    "To start with, we need an ONNX Bert model. We can get an ONNX Bert model by converting from Tensorflow.\n",
    "\n",
    "Follow instructions in [Converting a Tensorflow Bert model to ONNX](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/BertTutorial.ipynb) from step 1 through 6 to train and export Tensorflow Bert models to ONNX.  \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 - Optimize Model\n",
    "After we get the ONNX model, apply the graph transformer script on it to get an optimized graph. \n",
    "\n",
    "### Download the Bert Optimization Script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir bert_op_scripts\n",
    "!wget -O ./bert_op_scripts/bert_model_optimization.py https://raw.githubusercontent.com/microsoft/onnxruntime/master/onnxruntime/python/tools/bert/bert_model_optimization.py\n",
    "!wget -O ./bert_op_scripts/BertOnnxModelTF.py https://raw.githubusercontent.com/microsoft/onnxruntime/master/onnxruntime/python/tools/bert/BertOnnxModelTF.py\n",
    "!wget -O ./bert_op_scripts/BertOnnxModel.py https://raw.githubusercontent.com/microsoft/onnxruntime/master/onnxruntime/python/tools/bert/BertOnnxModel.py\n",
    "!wget -O ./bert_op_scripts/OnnxModel.py https://raw.githubusercontent.com/microsoft/onnxruntime/master/onnxruntime/python/tools/bert/OnnxModel.py\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the `bert_model_optimization.py` with `--framework tensorflow` option to optimize the converted model. Other notable options to use are \n",
    "- `--gpu_only`: allow half-precision float for better performance.\n",
    "- `--input_int32`: Use int32 tensors instead of default int64 as input to avoid un-necessary Cast nodes and get better performance.\n",
    "- `--float16`: Use float16 tensors instead of default float32 as input to enable half precision floats. Recommended for NVidia GPU with Tensor Core like V100 and T4. For older GPUs, float32 is likely faster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Below are three examples to run bert_model_optimization.py. Choose one according to your needs and adjust --input\n",
    "# --output path names as necessary.\n",
    "\n",
    "# For CPU\n",
    "!python bert_op_scripts/bert_model_optimization.py --input <bert.onnx> --output <bert_cpu.onnx> --framework tensorflow\n",
    "\n",
    "# # For inferences under NVidia GPU with Tensor Core like V100 and T4\n",
    "# !python bert_op_scripts/bert_model_optimization.py --input <bert.onnx> --output <bert_gpu_fp16.onnx> --framework tensorflow --gpu_only –float16\n",
    "\n",
    "# # For inferences under other NVidia GPUs except V100 and T4\n",
    "# !python bert_op_scripts/bert_model_optimization.py --input <bert.onnx> --output <bert_gpu_fp32.onnx> --framework tensorflow --gpu_only\n",
    "\n"
   ]
  },
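  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to check that the optimization took effect is to compare operator counts before and after it. The cell below is a minimal sketch, assuming the `onnx` Python package is installed; `count_ops` is a hypothetical helper and the file names are placeholders. An optimized graph would be expected to contain fewer nodes, with fused operators such as `Attention` replacing the original subgraphs, though the exact operator names depend on the script version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from collections import Counter\n",
    "\n",
    "import onnx\n",
    "\n",
    "def count_ops(model_path):\n",
    "    # Hypothetical helper: count nodes per operator type in the top-level graph.\n",
    "    model = onnx.load(model_path)\n",
    "    return Counter(node.op_type for node in model.graph.node)\n",
    "\n",
    "# Placeholder paths: replace with the --input and --output paths used above.\n",
    "for path in ['bert.onnx', 'bert_cpu.onnx']:\n",
    "    if os.path.exists(path):\n",
    "        counts = count_ops(path)\n",
    "        print(path, 'total nodes:', sum(counts.values()))\n",
    "        print(counts.most_common(10))\n",
    "    else:\n",
    "        print(path, 'not found; adjust the path')"
   ]
  },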
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 - Inference the Optimized Model with ONNX Runtime\n",
    "\n",
    "#### Install ONNX Runtime\n",
    "Install the latest ONNX Runtime if you haven't done so already. \n",
    "\n",
    "Install one `onnxruntime` python build. Choose to install `onnxruntime` to use CPU features, or `onnxruntime-gpu` to enjoy GPU execution providers such as CUDA. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install ONNX Runtime for CPU\n",
    "!{sys.executable} -m pip install -U onnxruntime\n",
    "\n",
    "## Alternatively, install onnxruntime for GPU\n",
    "# !{sys.executable} -m pip install -U onnxruntime-gpu"
   ]
  },
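  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After installation, it can help to confirm which build is active and which execution providers it offers. The cell below is a minimal sketch, assuming a reasonably recent ONNX Runtime release that provides `get_available_providers`; the output depends on your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnxruntime as rt\n",
    "\n",
    "print('ONNX Runtime version:', rt.__version__)\n",
    "# get_device() reports the device the installed build targets.\n",
    "print('Device:', rt.get_device())\n",
    "# Lists the execution providers available in this build.\n",
    "print('Available providers:', rt.get_available_providers())"
   ]
  },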
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we're ready to use ONNX Runtime to do inference on the optimized model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnxruntime as rt  \n",
    "import numpy as np\n",
    "import time\n",
    "\n",
    "sess_options = rt.SessionOptions()\n",
    "\n",
    "# Set graph optimization level to ORT_ENABLE_EXTENDED to enable bert optimization.\n",
    "sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED\n",
    "\n",
    "session = rt.InferenceSession(\"./bert_squad_op.onnx\", sess_options)\n",
    "\n",
    "# evaluate the model\n",
    "# Generate dummy inputs to the model. Adjust if neccessary\n",
    "inputs = {\n",
    "    'input_ids:0':   np.random.randint(0, 256, size=[1, 256], dtype=np.int64), # list of numerical ids for the tokenised text\n",
    "    'segment_ids:0': np.ones(shape=[1, 256], dtype=np.int64),        # dummy list of ones\n",
    "    'input_mask:0':  np.ones(shape=[1, 256], dtype=np.int64),        # dummy list of ones\n",
    "    'unique_ids_raw_output___9:0': np.arange(0, 256, dtype=np.int64)\n",
    "}\n",
    "\n",
    "start = time.time()\n",
    "# Run the optimized model with inputs\n",
    "output_names = ['unstack:1', 'unstack:0', 'unique_ids:0']\n",
    "res = session.run(output_names, inputs) \n",
    "end = time.time()\n",
    "print(\"ONNX Runtime Inference time: \", end - start)"
   ]
  },
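  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A single timed call can include one-time costs such as memory allocation, so it may not reflect steady-state latency. The cell below is a minimal sketch of a more stable measurement, assuming the `session`, `output_names`, and `inputs` variables from the previous cell. `benchmark` is a hypothetical helper, and the warm-up and run counts are arbitrary choices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def benchmark(run_once, warmup=3, runs=20):\n",
    "    # Hypothetical helper: run a few untimed warm-up calls, then time each call.\n",
    "    for _ in range(warmup):\n",
    "        run_once()\n",
    "    latencies = []\n",
    "    for _ in range(runs):\n",
    "        start = time.perf_counter()\n",
    "        run_once()\n",
    "        latencies.append(time.perf_counter() - start)\n",
    "    return np.array(latencies)\n",
    "\n",
    "# Reuses session, output_names, and inputs from the previous cell.\n",
    "latencies = benchmark(lambda: session.run(output_names, inputs))\n",
    "print('ONNX Runtime mean latency: %.2f ms' % (latencies.mean() * 1000))\n",
    "print('ONNX Runtime p95 latency: %.2f ms' % (np.percentile(latencies, 95) * 1000))"
   ]
  },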
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the perf numbers from TensorFlow model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Get input and output keys\n",
    "with tf.Session(graph=tf.Graph()) as sess:\n",
    "    # Load TensorFlow saved model\n",
    "    metagraph = tf.saved_model.loader.load(sess, \n",
    "                                           [tf.saved_model.tag_constants.SERVING], \n",
    "                                           \"./saved_model\")\n",
    "    \n",
    "    # Get the input/output names for the saved model.\n",
    "    inputs_mapping = dict(metagraph.signature_def['serving_default'].inputs)\n",
    "    input_names = [inputs_mapping[i].name for i in inputs_mapping.keys()]\n",
    "    print(\"input names \", input_names)\n",
    "    start = time.time()\n",
    "    out = sess.run(output_names, \n",
    "                   {input_names[0]: inputs[\"unique_ids_raw_output___9:0\"],\n",
    "                     input_names[1]: inputs[\"segment_ids:0\"],\n",
    "                     input_names[2]: inputs[\"input_ids:0\"], \n",
    "                     input_names[3]: inputs[\"input_mask:0\"]})\n",
    "    end = time.time()\n",
    "    print(\"\\n\")\n",
    "    print(\"Tensorflow Inference time: \", end - start)\n",
    "    print(\"\\n\")\n",
    "    print(\"***** Verifying correctness *****\")\n",
    "    for i in range(3):\n",
    "        print('Tensorflow and ONNX Runtime matching numbers:', np.allclose(res[i], out[i], rtol=1e-05, atol=1e-02))\n",
    "\n",
    "    sess.close()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
